import hydra
import torch
from omegaconf import DictConfig, OmegaConf

from Initializers.data import initialize_data
from Initializers.init_utils import init_logistics, init_loading
from Initializers.initialize_models import initialize_models
from Initializers.wandb_logging_policy import initialize_wandb
from Policy.gcrl_trainer import GCTrainer


@hydra.main(version_base=None, config_path="./configs", config_name="config")
def train_HRL(config: DictConfig):
    """Set up and run training with ``GCTrainer``.

    The steps, in order: initialize wandb, build the environments and logger,
    pick the torch device, build the models, build the collectors and replay
    buffer, optionally load saved state, optionally collect random warm-up
    data, then run the trainer and print its result.

    Args:
        config: Configuration supplied by the ``@hydra.main`` decorator
            (``config_path="./configs"``, ``config_name="config"``). It is
            replaced by the config returned from ``init_logistics`` before
            ``cuda_id``, ``exp_path``, ``replay_buffer_dir``, ``load.*``,
            ``save.*`` and ``train.*`` are read.

    Returns:
        None. The value returned by ``GCTrainer(...).run()`` is printed.
    """
    config = OmegaConf.structured(OmegaConf.to_yaml(config))
    # recovers the hydra config, which reads from ./configs/config.yaml (modify the hydra parameters)
    # NOTE(review): the YAML round-trip presumably yields a fresh config that
    # accepts new keys (e.g. `device` below) -- confirm.

    wdb_run = initialize_wandb(config)  # the wandb config specifies the name from record.wandb_log
    # create the environment and assign the number of variables and observation shape
    single_env, train_env, test_env, norm, logger, config = init_logistics(config, wdb_run=wdb_run)

    # Similar to code in tianshou.examples
    config.device = torch.device(f"cuda:{config.cuda_id}" if torch.cuda.is_available() else "cpu")

    # create the models, this is the most involved part of initialization
    # (the reward model is returned but not used directly in this function)
    dynamics, graph_encoding, _reward, policy = initialize_models(config, single_env, norm, wdb_run)

    # initializes a vectored buffer (for the number of environments),
    # and the train and test collectors (handling the action-observation-reward loop)
    train_collector, test_collector, buffer = \
        initialize_data(config, policy, dynamics, single_env, train_env, test_env, norm)

    # loads a buffer from memory if necessary
    buffer = init_loading(config,
                          dynamics,
                          graph_encoding,
                          policy,
                          buffer)

    # Need to assign the buffer to the collector
    train_collector.buffer = buffer

    # the below functions are called from the trainer loop (collection-policy update)
    def save_best_fn(policy):
        """Write the given policy's state dict to ``<exp_path>/policy.pth``."""
        torch.save(policy.state_dict(), config.exp_path / "policy.pth")

    def stop_fn(mean_rewards):
        """Never request early stopping."""
        return False

    def train_fn(epoch, env_step):
        """Toggle random collection and policy updates around the dynamics warm-up.

        Up to and including ``config.train.dynamics_warmup_step`` environment
        steps the collector is flagged ``random_all`` and ``policy.use_update``
        is off; afterwards both flags are flipped.
        """
        if env_step <= config.train.dynamics_warmup_step:
            train_collector.random_all = True
            policy.use_update = False
        else:
            train_collector.random_all = False
            policy.use_update = True

    def test_fn(epoch, env_step):
        """Record the current epoch on the test collector."""
        # This is used for naming the gif file
        test_collector.epoch_num = epoch

    def save_checkpoint_fn(epoch, env_step, gradient_step):
        """Save the policy (and optionally the replay buffer) every ``save_freq`` epochs.

        Returns:
            The checkpoint path when a checkpoint was written, otherwise None.
        """
        save_freq = config.save.save_freq
        # A zero or unset save frequency disables periodic checkpoints. Without
        # this guard `epoch % 0` raises ZeroDivisionError on the first non-zero
        # epoch, aborting a run that has already spent time training.
        if not save_freq or epoch == 0 or epoch % save_freq != 0:
            return None
        print("saving checkpoints")
        # see also: https://pytorch.org/tutorials/beginner/saving_loading_models.html
        ckpt_path = config.exp_path / f"policy_{epoch}.pth"
        torch.save(policy.state_dict(), ckpt_path)

        if config.save.save_replay_buffer:
            print("saving replaybuffer")
            buffer_path = config.replay_buffer_dir / "buffer.hdf5"
            train_collector.buffer.save_hdf5(buffer_path)
        return ckpt_path

    if config.load.load_rpb:
        print("skipping initial data collection because of replay buffer loading")
    elif config.train.init_random_step > 0:
        # start filling replay buffer with random actions
        train_collector.collect(n_step=config.train.init_random_step, random=True)
        print("finished initial data collection")

    # the tianshou trainer .run() function acts like an iterator, repeatedly calling the
    # __next__(), which alternates collector.collect and GCTrainer.policy_update_fn
    # with logging interspersed throughout
    result = GCTrainer(
        policy,
        dynamics,
        train_collector,
        test_collector,
        config.train.epoch,
        config.train.env_step_per_epoch,
        config.train.env_step_per_collect,
        config.train.policy_update_per_env_step,
        config.train.dynamics_update_per_env_step,
        config.train.policy_batch_size,
        config.train.dynamics_batch_size,
        config.train.test_ep_per_epoch,
        train_fn=train_fn,
        dynamics_warmup_step=config.train.dynamics_warmup_step,
        test_fn=test_fn,
        stop_fn=stop_fn,
        save_best_fn=save_best_fn,
        save_checkpoint_fn=save_checkpoint_fn,
        # resume_from_log=config.load.resume_from_log,  # Todo: this requires setting the logger correctly
        logger=logger,
        fpg=config.train.fpg,
        n_steps_per_goal=config.train.n_steps_per_goal
    ).run()
    print("final result", result)


# Script entry point: train_HRL is called without arguments because the
# @hydra.main decorator on it is what supplies the `config` parameter.
if __name__ == "__main__":
    train_HRL()
